import torch
from torch.utils.data import Subset
from typing import List
from sklearn.neighbors import kneighbors_graph
import numpy as np
from tqdm import tqdm
import math
import random
import os
import matplotlib.pyplot as plt

def set_seed(seed: int = 42):
    """Seed every random number generator this module relies on.

    Seeds Python's ``random``, NumPy's global generator and PyTorch's CPU
    generator; when CUDA is available, all CUDA devices are seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def get_device():
    """Return the preferred torch device: CUDA first, then MPS, otherwise CPU."""
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)

def fmt_nf(x: float) -> str:
    """Filesystem-friendly tag for a noise factor, e.g. ``0.25`` -> ``"0p25"``.

    The value is rounded to three decimals, trailing zeros are dropped and the
    decimal point (if any digits remain after it) is written as ``p``.
    """
    whole, _, fraction = f"{x:.3f}".partition(".")
    fraction = fraction.rstrip("0")
    if not fraction:
        return whole
    return f"{whole}p{fraction}"

def add_noise(inputs, noise_factor):
    """
    Corrupts images with additive Gaussian noise and clamps the result to [0, 1].

    Parameters:
    - inputs: A tensor holding the input image(s).
    - noise_factor: The standard deviation of the added noise.

    Returns:
    - A tensor with the same shape as `inputs` holding the noisy image(s).
    """
    perturbation = torch.randn_like(inputs) * noise_factor
    return (inputs + perturbation).clamp(0., 1.)


def gaussian_kernel(A: torch.Tensor, B: torch.Tensor, sigma: float = 1/(28*28)) -> torch.Tensor:
    """
    Computes the Gaussian (RBF) kernel between two sets of points.

    The value for a pair (a, b) is ``exp(-sigma * ||a - b||^2)``. Note that
    ``sigma`` multiplies the squared distance directly, so it acts as an
    inverse-width (scale) parameter rather than a standard deviation: larger
    values give a narrower kernel.

    Parameters:
    - A (torch.Tensor): A tensor of shape (n, d) representing n points in d-dimensional space.
    - B (torch.Tensor): A tensor of shape (m, d) representing m points in d-dimensional space.
    - sigma (float, optional): Scale applied to the squared Euclidean distance. Default is 1/(28*28).

    Returns:
    - torch.Tensor: A tensor of shape (n, m) representing the Gaussian kernel values between points in A and B.
    """
    # Broadcast to all pairwise differences; this materialises an (n, m, d) tensor.
    diff = A.unsqueeze(1) - B.unsqueeze(0)
    # Sum of squares gives the squared distance directly, without a sqrt round-trip.
    sq_dist = diff.pow(2).sum(dim=2)
    return torch.exp(-sigma * sq_dist)

def kneighbors_graph_torch(X: torch.Tensor, neighbors: int) -> torch.Tensor:
    """
    Build a directed k-nearest-neighbor graph in PyTorch (GPU-capable).

    Parameters:
    - X (torch.Tensor): A tensor of shape (batch, d); each row is one point.
    - neighbors (int): Number of nearest neighbors to connect each point to.
      If it exceeds ``batch - 1`` it is reduced to ``batch - 1``, since a point
      has no more than that many other points to connect to.

    Returns:
    - torch.Tensor: Adjacency matrix of shape (batch, batch), on X's device,
      with entry (i, j) equal to 1.0 when j is one of the nearest neighbors of
      i (by Euclidean distance) and 0.0 otherwise. A point is never its own
      neighbor, so the diagonal is always zero. The matrix is generally not
      symmetric.
    """
    batch = X.shape[0]
    device = X.device
    N_X = torch.zeros((batch, batch), device=device)

    # A point has at most batch - 1 candidates besides itself.
    k = min(neighbors, batch - 1)
    if k <= 0:
        return N_X

    # The graph is a discrete structure; no gradient needs to flow through it.
    with torch.no_grad():
        # Pairwise Euclidean distances, [batch, batch].
        dist = torch.cdist(X, X, p=2)
        # Exclude self explicitly. Dropping the first sorted column is not
        # enough: with duplicate points the self index need not sort first.
        dist.fill_diagonal_(float("inf"))
        knn_idx = dist.topk(k, dim=1, largest=False).indices

    # Mark neighbors with 1.
    N_X.scatter_(1, knn_idx, 1.0)
    return N_X


def ECMMD(Z, Y, X, kernel, neighbors: int):
    """
    Compute a nearest-neighbor-weighted MMD-type statistic between Z and Y.

    A k-nearest-neighbor graph is built on the rows of X. For every edge
    i -> j of that graph the quantity
    ``k(Z_i, Z_j) + k(Y_i, Y_j) - k(Z_i, Y_j) - k(Y_i, Z_j)`` is accumulated,
    and the total is divided by ``batch * neighbors``.

    Parameters:
    - Z, Y: Row-wise samples passed to `kernel`; the pairwise kernel matrices
      must have the same shape as the (batch, batch) graph of X.
    - X: Points on which the nearest-neighbor graph is built, one per row.
    - kernel: Callable ``kernel(A, B)`` returning the pairwise kernel matrix.
      It must be built from PyTorch ops for gradients to propagate.
    - neighbors: Number of nearest neighbors used for the graph.

    Returns:
    - A scalar tensor holding the statistic.
    """
    adjacency = kneighbors_graph_torch(X, neighbors)

    # Pairwise kernel matrices within and across the two samples.
    k_zz = kernel(Z, Z)
    k_yy = kernel(Y, Y)
    k_zy = kernel(Z, Y)
    k_yz = kernel(Y, Z)

    H = k_zz + k_yy - k_zy - k_yz

    normalizer = X.shape[0] * neighbors
    return (H * adjacency).sum() / normalizer


def split_dataset_by_class(dataset, train_samples_per_class, val_samples_per_class):
    """
    Split a labelled dataset into class-balanced training and validation subsets.

    Each class's sample indices are shuffled with NumPy's global random
    generator; the first `train_samples_per_class` go to training and the next
    `val_samples_per_class` to validation. If a class has fewer samples than
    requested, it simply contributes fewer. Both index lists are shuffled once
    more before the subsets are built.

    Parameters:
    - dataset: Iterable, indexable dataset yielding ``(sample, label)`` pairs.
      Labels may be plain hashable values or scalar tensors/NumPy scalars;
      any set of class ids is accepted.
    - train_samples_per_class: Number of training samples to take per class.
    - val_samples_per_class: Number of validation samples to take per class.

    Returns:
    - (train_subset, val_subset): two `Subset` objects over `dataset`.
    """
    # Group sample indices by label without assuming which labels exist.
    class_indices = {}
    for idx, (_, label) in enumerate(dataset):
        # Scalar tensors hash by identity, so reduce them to a Python value.
        key = label.item() if hasattr(label, "item") else label
        class_indices.setdefault(key, []).append(idx)

    train_indices = []
    val_indices = []
    val_end = train_samples_per_class + val_samples_per_class
    # Sorted class order keeps the random draws reproducible under a fixed seed.
    for cls in sorted(class_indices):
        indices = np.array(class_indices[cls])
        np.random.shuffle(indices)
        # First part for training, next part for validation
        train_indices.extend(indices[:train_samples_per_class])
        val_indices.extend(indices[train_samples_per_class:val_end])

    # Shuffle the final index lists so classes are interleaved.
    train_indices = torch.tensor(np.random.permutation(train_indices))
    val_indices = torch.tensor(np.random.permutation(val_indices))

    return Subset(dataset, train_indices), Subset(dataset, val_indices)

def select_samples_by_class(dataset, num_samples_per_class=500):
    """
    Selects a fixed number of samples per class from a labelled dataset.

    Samples are drawn without replacement using NumPy's global random
    generator, and the combined selection is shuffled before the subset is
    built.

    Parameters:
    - dataset: Iterable, indexable dataset yielding ``(sample, label)`` pairs
      (e.g. MNIST train or test split). Labels may be plain hashable values or
      scalar tensors/NumPy scalars; any set of class ids is accepted.
    - num_samples_per_class: Number of samples to be chosen for each class.

    Returns:
    - `Subset` of `dataset` with `num_samples_per_class` samples in each class.

    Raises:
    - ValueError: If some class has fewer than `num_samples_per_class` samples.
    """
    # Group sample indices by label without assuming which labels exist.
    class_indices = {}
    for idx, (_, label) in enumerate(dataset):
        # Scalar tensors hash by identity, so reduce them to a Python value.
        key = label.item() if hasattr(label, "item") else label
        class_indices.setdefault(key, []).append(idx)

    selected_indices = []
    # Sorted class order keeps the random draws reproducible under a fixed seed.
    for cls in sorted(class_indices):
        indices = class_indices[cls]
        if len(indices) < num_samples_per_class:
            raise ValueError(
                f"class {cls!r} has only {len(indices)} samples, "
                f"cannot select {num_samples_per_class} without replacement"
            )
        selected_indices.extend(np.random.choice(indices, num_samples_per_class, replace=False))
    selected_indices = torch.tensor(np.random.permutation(selected_indices))
    return Subset(dataset, selected_indices)

def _squared_ecmmd(ecmmd_fn, kernel, denoised, clean, noisy, neighbors):
    """Return the squared ECMMD statistic of one batch, flattening each image to a vector."""
    return ecmmd_fn(
        denoised.reshape(len(denoised), -1),
        clean.reshape(len(clean), -1),
        noisy.reshape(len(noisy), -1),
        kernel=kernel,
        neighbors=neighbors,
    ) ** 2


def _show_denoising_progress(model, test_images, noisy_test_images, test_eta, plot_idx, device):
    """Plot the clean, noisy and currently denoised version of one test image side by side."""
    with torch.no_grad():
        noisy_batch = noisy_test_images[plot_idx].unsqueeze(0).to(device)
        eta_batch = test_eta[plot_idx].unsqueeze(0).to(device)
        denoised = model(noisy_batch, eta_batch).cpu()

    panels = (
        ('Actual Image', test_images[plot_idx].cpu()),
        ('Noisy Image', noisy_test_images[plot_idx].cpu()),
        ('Denoised Image', denoised),
    )
    plt.figure(figsize=(5, 3))
    for position, (title, image) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(image.squeeze(), cmap='gray')
        plt.title(title)
        plt.axis('off')
    plt.show()


def train_model(model, 
                train_dataloader, 
                validation_dataloader, 
                test_images, 
                noisy_test_images, 
                test_eta, 
                optimizer, 
                ECMMD, 
                gaussian_kernel, 
                NEIGHBORS, 
                NUM_EPOCH, 
                plot_idx=0):
    """
    Trains the model and evaluates on a validation set, printing the losses and
    plotting one test image every 30 epochs (epochs 0, 30, 60, ...).

    The loss for a batch is the square of the ECMMD statistic computed between
    the denoised and clean images, with the nearest-neighbor graph built on the
    noisy images; all images are flattened to vectors first. The model is moved
    to the device returned by `get_device()` and is left in eval mode on return.

    Parameters:
        model: The neural network model, called as ``model(noisy_images, eta)``.
        train_dataloader: DataLoader for the training dataset, yielding
            ``(noisy_images, eta, clean_images)`` batches.
        validation_dataloader: DataLoader for the validation dataset, yielding
            batches in the same order.
        test_images: Ground truth images used for plotting.
        noisy_test_images: Noisy test images used for input to the model.
        test_eta: Eta values associated with the test images.
        optimizer: Optimizer for training.
        ECMMD: A function to compute the error metric, called as
            ``ECMMD(Z, Y, X, kernel=..., neighbors=...)``.
        gaussian_kernel: Kernel function used in the ECMMD metric.
        NEIGHBORS: Number of neighbors parameter for ECMMD.
        NUM_EPOCH: Total number of training epochs.
        plot_idx: Index of the test image to be plotted (default is 0).

    Returns:
        model: The trained model (the same object that was passed in).
        train_losses: List of average training losses per epoch.
        val_losses: List of average validation losses per epoch.
    """
    device = get_device()
    model.to(device)

    train_losses = []
    val_losses = []

    for epoch in tqdm(range(NUM_EPOCH)):
        # ---- training pass ----
        model.train()
        total_train_loss = 0.0
        num_train_batches = 0

        for noisy_train_images, train_eta, train_images in train_dataloader:
            optimizer.zero_grad()
            noisy_train_images = noisy_train_images.to(device)
            train_eta = train_eta.to(device)
            train_images = train_images.to(device)

            denoised_train_images = model(noisy_train_images, train_eta)
            loss = _squared_ecmmd(
                ECMMD, gaussian_kernel,
                denoised_train_images, train_images, noisy_train_images,
                NEIGHBORS,
            )
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            num_train_batches += 1

        # An empty dataloader yields a loss of 0 rather than a division by zero.
        avg_train_loss = total_train_loss / num_train_batches if num_train_batches > 0 else 0
        train_losses.append(avg_train_loss)

        # ---- validation pass ----
        model.eval()
        total_val_loss = 0.0
        num_val_batches = 0
        with torch.inference_mode():
            for noisy_validation_images, validation_eta, validation_images in validation_dataloader:
                noisy_validation_images = noisy_validation_images.to(device)
                validation_eta = validation_eta.to(device)
                validation_images = validation_images.to(device)

                denoised_validation_images = model(noisy_validation_images, validation_eta)
                val_loss = _squared_ecmmd(
                    ECMMD, gaussian_kernel,
                    denoised_validation_images, validation_images, noisy_validation_images,
                    NEIGHBORS,
                )
                total_val_loss += val_loss.item()
                num_val_batches += 1

        avg_val_loss = total_val_loss / num_val_batches if num_val_batches > 0 else 0
        val_losses.append(avg_val_loss)

        # ---- periodic progress report ----
        if epoch % 30 == 0:
            print(f'Epoch {epoch}, Training Loss: {avg_train_loss}, Validation Loss: {avg_val_loss}')
            _show_denoising_progress(
                model, test_images, noisy_test_images, test_eta, plot_idx, device
            )

    return model, train_losses, val_losses



@torch.no_grad()
def plot_denoise_grid(
    *,
    digits: list[int],
    images_per_digit: int,
    model: torch.nn.Module,
    test_images: torch.Tensor,
    noisy_test_images: torch.Tensor,
    test_labels: torch.Tensor,
    device,
    eta_dim: int,
    samples_per_image: int = 100,
    save_path: str | None = None
):
    """
    Creates a 4-row grid:
      Row 0: True Images
      Row 1: Noisy Images
      Row 2: Denoised (Average over 'samples_per_image' draws of η)
      Row 3: Std-dev of generated images across the 'samples_per_image' draws

    All images are displayed inverted (``1 - image``) with a gray colormap.
    For each entry of `digits`, the first `images_per_digit` test images whose
    label equals that digit get one column each. If a digit has fewer matches,
    the leftover columns are blanked.

    Parameters:
        digits: Labels to display, in column order.
        images_per_digit: Maximum number of images shown per label.
        model: Network called as ``model(noisy_image_batch, eta)``; it is put
            in eval mode.
        test_images: Ground truth images, indexed by sample.
        noisy_test_images: Noisy counterparts of `test_images`.
        test_labels: Tensor of labels aligned with the test images.
        device: Device on which the model inputs are placed.
        eta_dim: Side length of η; each draw is standard normal noise of shape
            ``(1, eta_dim, eta_dim)``.
        samples_per_image: Number of η draws per image.
        save_path: If given, the figure is saved there in PDF format at 300 dpi.
    """
    model.eval()

    num_images = images_per_digit * len(digits)
    fig, axes = plt.subplots(4, num_images, figsize=(num_images * 1.5, 8))
    if num_images == 1:
        # Keep indexing uniform even if one column
        axes = axes.reshape(4, 1)

    col = 0
    for d in digits:
        idx_for_digit = (test_labels == d).nonzero(as_tuple=True)[0]

        for i in idx_for_digit[:images_per_digit].tolist():
            # Row 0 — True (moved to CPU so tensors on an accelerator can be plotted)
            axes[0, col].imshow(1 - test_images[i].cpu().squeeze(), cmap='gray')
            axes[0, col].axis('off')

            # Row 1 — Noisy
            axes[1, col].imshow(1 - noisy_test_images[i].cpu().squeeze(), cmap='gray')
            axes[1, col].axis('off')

            # Row 2 & 3 — Denoised mean & std across η draws
            noisy_img = noisy_test_images[i].unsqueeze(0).to(device)
            outs = []
            for _ in range(samples_per_image):
                eta = torch.randn(1, eta_dim, eta_dim, device=device)
                outs.append(model(noisy_img, eta))

            # Leading axis indexes the η draws; presumably (S, 1, 28, 28) for MNIST.
            outs = torch.stack(outs, dim=0)
            denoised_avg = outs.mean(dim=0).cpu()
            denoised_std = outs.std(dim=0).cpu()

            axes[2, col].imshow(1 - denoised_avg.squeeze(), cmap='gray')
            axes[2, col].axis('off')

            axes[3, col].imshow(1 - denoised_std.squeeze(), cmap='gray')
            axes[3, col].axis('off')

            col += 1

    # Blank any columns left unused because a digit had too few matching images.
    for unused_ax in axes[:, col:].ravel():
        unused_ax.axis('off')

    plt.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, dpi=300, format='pdf', bbox_inches='tight')
        print(f"Saved figure to: {save_path}")